Skip-gram embedding (variant of Word2Vec)
  • Generate skip-gram embeddings using the open-source Gensim library.
    • The create_skipgram_embeddings function uses Gensim's Word2Vec with sg=1 to train a skip-gram model. The sentences are tokenized and passed to the model, which learns an embedding for each word in the vocabulary.
    • Each word's vector is stored in the word_embeddings dictionary.
  • Store the embeddings in a FAISS vector database for efficient similarity search.
    • The create_faiss_index function converts the embeddings into a matrix format and stores them in a FAISS index using faiss.IndexFlatL2.
    • FAISS (Facebook AI Similarity Search) is a library for efficient similarity search and clustering of dense vectors.
  • Implement Retrieval-Augmented Generation (RAG) to query the database with user input and retrieve the nearest words for each query token.
    • The rag_query function takes a user query, tokenizes it, and queries the FAISS index for the closest word embeddings.
    • For each word in the query, the FAISS search method returns the indices and distances of the k most similar words.
  • Display Results:
    • The original embeddings table lists each word along with its index and embedding.
    • The user query is displayed as plain text.
    • The retrieved data table shows, for each query word, the predicted word (from the FAISS search), the index of the predicted word, its distance from the query word, and its embedding.
    • Heatmap Visualization
      • The display_retrieval_heatmap function plots a heatmap of the cosine similarities between each query word and the words retrieved for it from the skip-gram embedding space.
      • Cosine similarity is computed with cosine_similarity from sklearn.metrics.pairwise between the embedding of each query word and the embeddings of the retrieved words (see the formula below).
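
For reference, the cosine similarity between two embedding vectors $\mathbf{u}$ and $\mathbf{v}$ is

$$\cos(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u} \cdot \mathbf{v}}{\lVert \mathbf{u} \rVert \, \lVert \mathbf{v} \rVert},$$

so values near 1 mean the two vectors point in nearly the same direction, while values near 0 mean they are essentially unrelated.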
In [ ]:
%pip install -q gensim faiss-cpu pandas matplotlib
%pip install -q prettytable seaborn scikit-learn
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
In [ ]:
import gensim
import faiss
import numpy as np
import pandas as pd
from prettytable import PrettyTable
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

# Step 1: Create Skip-Gram Embedding using Gensim
def create_skipgram_embeddings(sentences):
    # Tokenize the sentences and create a Word2Vec model (skip-gram)
    tokenized_sentences = [sentence.lower().split() for sentence in sentences]
    model = gensim.models.Word2Vec(sentences=tokenized_sentences, vector_size=100, window=5, sg=1, min_count=1, workers=4)

    # Extract word embeddings from the model
    word_embeddings = {word: model.wv[word] for word in model.wv.index_to_key}
    return model, word_embeddings

# Sample sentences for skip-gram model
sentences = [
    "The quick brown fox jumped over the lazy dog",
    "Artificial intelligence is the future of technology",
    "Natural language processing is a subset of AI",
    "Skip-gram model learns word representations from context",
    "Word embeddings are used to improve machine learning models"
]

# Create embeddings using skip-gram model
model, word_embeddings = create_skipgram_embeddings(sentences)
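
# Illustrative aside (not part of the pipeline): enumerate the (target, context)
# training pairs a skip-gram model sees for one sentence. Gensim generates these
# internally (with a randomly shrunk window per target); this hypothetical helper
# uses a fixed window to make the idea concrete.
def skipgram_pairs(tokens, window=2):
    pairs = []
    for i, target in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        pairs.extend((target, tokens[j]) for j in range(lo, hi) if j != i)
    return pairs

print(skipgram_pairs("the quick brown fox".split(), window=2))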

# Step 2: Store the embeddings in a FAISS index
def create_faiss_index(word_embeddings):
    embedding_matrix = np.array([embedding for embedding in word_embeddings.values()])
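    # Note: the dict preserves insertion order, so row i of the matrix (and of the
    # FAISS index) corresponds to model.wv.index_to_key[i]; the display functions
    # below rely on this correspondence to map FAISS indices back to words.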
    
    # Initialize a FAISS index for L2 distance
    dimension = embedding_matrix.shape[1]  # Dimensionality of the embeddings
    index = faiss.IndexFlatL2(dimension)
    
    # Add the embeddings to the FAISS index
    index.add(embedding_matrix)
    
    return index

# Create FAISS index with the embeddings
faiss_index = create_faiss_index(word_embeddings)
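
# Sanity check (illustrative): the index holds one vector per vocabulary word, and
# searching a stored vector against the index should return itself at distance ~0.
assert faiss_index.ntotal == len(word_embeddings)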

# Step 3: Implement Retrieval-Augmented Generation (RAG) to query the database and get top k embeddings
def rag_query(user_query, model, index, k=2):
    query_tokens = user_query.lower().split()  # Tokenize the query
    results = []
    
    # For each word in the query, retrieve the closest k words from FAISS
    for token in query_tokens:
        if token in model.wv:  # tokens missing from the training vocabulary are skipped
            query_embedding = model.wv[token].reshape(1, -1)  # Reshape for FAISS query
            distances, indices = index.search(query_embedding, k)  # Search the FAISS index for top k results
            results.append((token, indices[0], distances[0]))  # Collect word, indices, and distances
    
    return results

# Example user query
user_query = "future of technology"
results = rag_query(user_query, model, faiss_index, k=2)

# Step 4: Display results

# 4.1: Display the original word embeddings in a neat table format using PrettyTable
def display_original_embeddings(word_embeddings):
    # Create a PrettyTable for Original Word Embeddings with Index
    table = PrettyTable()
    table.field_names = ["Index", "Word", "Embedding"]

    for idx, word in enumerate(word_embeddings.keys()):
        table.add_row([idx, word, str(word_embeddings[word])[:50]])  # Shorten embedding for display

    return table

# 4.2: Display the retrieved data in a table format using PrettyTable
def display_retrieved_data(results, model):
    # Create a PrettyTable for Retrieved Data
    table = PrettyTable()
    table.field_names = ["Query Word", "Predicted Word", "Index", "Distance", "Embedding"]

    for word, indices, distances in results:
        for idx, dist in zip(indices, distances):
            retrieved_word = model.wv.index_to_key[idx]  # Get the word from FAISS index
            table.add_row([word, retrieved_word, idx, dist, str(model.wv[retrieved_word])[:50]])  # Shorten embedding for display

    return table

# 4.3: Display the heatmap of the retrieved words based on cosine similarity
def display_retrieval_heatmap(user_query, results, model):
    # Extract the embeddings of the query word and its top-k retrieved words
    query_words = [result[0] for result in results]  # Words from user query
    retrieved_words = []
    embeddings = []
    
    for word, indices, _ in results:
        for idx in indices:
            retrieved_word = model.wv.index_to_key[idx]
            retrieved_words.append(retrieved_word)
            embeddings.append(model.wv[retrieved_word])
    
    # Query word embeddings
    query_embeddings = np.array([model.wv[word] for word in query_words])
    
    # Calculate cosine similarity between the query and retrieved words
    similarity_matrix = cosine_similarity(query_embeddings, embeddings)

    # Create a DataFrame for better visualization
    similarity_df = pd.DataFrame(similarity_matrix, index=query_words, columns=retrieved_words)

    # Display the heatmap
    plt.figure(figsize=(8, 3.5))
    sns.heatmap(similarity_df, annot=True, cmap="YlGnBu", cbar=True, linewidths=0.5)
    plt.title(f"Cosine Similarity Heatmap for Query '{user_query}' and Retrieved Words")
    plt.show()

# Display the original word embeddings table
print("Original Word Embeddings (with Index):")
original_embeddings_table = display_original_embeddings(word_embeddings)
print(original_embeddings_table)

# Display original user query as text
print("\nOriginal User Query:", f"'{user_query}'")

# Display the retrieved words and their embeddings
print("\nRetrieved Data for User Query:", f"'{user_query}'")
retrieved_data_table = display_retrieved_data(results, model)
print(retrieved_data_table)

# Display the heatmap for the retrieved words
display_retrieval_heatmap(user_query, results, model)
Original Word Embeddings (with Index):
+-------+-----------------+----------------------------------------------------+
| Index |       Word      |                     Embedding                      |
+-------+-----------------+----------------------------------------------------+
|   0   |       the       | [-5.3701916e-04  2.3689654e-04  5.1034042e-03  9.0 |
|   1   |        of       | [-8.6196875e-03  3.6657380e-03  5.1898835e-03  5.7 |
|   2   |        is       | [ 9.89265172e-05  3.08311544e-03 -6.81140460e-03 - |
|   3   |       word      | [-8.2426788e-03  9.2993546e-03 -1.9766092e-04 -1.9 |
|   4   |    processing   | [-0.00713803  0.00124788 -0.00718535 -0.00223861   |
|   5   |     natural     | [-8.7241465e-03  2.1291552e-03 -8.6793368e-04 -9.3 |
|   6   |    technology   | [ 8.1322715e-03 -4.4573355e-03 -1.0683584e-03  1.0 |
|   7   |      future     | [ 8.1687383e-03 -4.4429051e-03  8.9858277e-03  8.2 |
|   8   |   intelligence  | [-9.5796995e-03  8.9406269e-03  4.1691624e-03  9.2 |
|   9   |    artificial   | [-0.0051557  -0.00666764 -0.00777601  0.00831251 - |
|   10  |       dog       | [ 7.0887972e-03 -1.5679311e-03  7.9474971e-03 -9.4 |
|   11  |       lazy      | [ 9.7717736e-03  8.1663514e-03  1.2811647e-03  5.0 |
|   12  |       over      | [-1.9447126e-03 -5.2659069e-03  9.4463034e-03 -9.2 |
|   13  |      jumped     | [-0.00950012  0.00956222 -0.00777076 -0.00264551 - |
|   14  |       fox       | [ 7.6962691e-03  9.1207959e-03  1.1379765e-03 -8.3 |
|   15  |      brown      | [-7.1894615e-03  4.2341282e-03  2.1635876e-03  7.4 |
|   16  |      quick      | [ 1.30016566e-03 -9.80430376e-03  4.58776252e-03 - |
|   17  |     language    | [ 0.00180023  0.00704609  0.0029447  -0.00698085   |
|   18  |      models     | [ 0.00973555 -0.00978038 -0.00649949  0.00278378   |
|   19  |     learning    | [ 5.6248982e-03  5.4965699e-03  1.8382617e-03  5.7 |
|   20  |      subset     | [ 0.0025627   0.00085199 -0.00254519  0.00936274   |
|   21  |        ai       | [ 1.3443247e-03  6.5492862e-03  9.9802474e-03  9.0 |
|   22  |    skip-gram    | [-2.3460490e-04  4.2256210e-03  2.1153577e-03  1.0 |
|   23  |      model      | [-0.00250877 -0.00590266  0.00748334 -0.00725973 - |
|   24  |      learns     | [-4.9735666e-03 -1.2833046e-03  3.2806373e-03 -6.4 |
|   25  | representations | [ 0.00965136  0.00732719  0.00125347 -0.0034065  - |
|   26  |       from      | [-6.9636083e-03 -2.4585128e-03 -8.0229379e-03  7.5 |
|   27  |     context     | [ 0.00211351  0.00573515 -0.00211641  0.0031723    |
|   28  |    embeddings   | [ 8.3545828e-03 -5.6522468e-04 -9.4374381e-03  4.7 |
|   29  |       are       | [-4.2830892e-03 -9.3249166e-03 -1.8744217e-03 -3.7 |
|   30  |       used      | [ 1.9264125e-04  2.1422671e-03  1.0679637e-03  7.5 |
|   31  |        to       | [-0.00219433 -0.00970297  0.00929525  0.00203881 - |
|   32  |     improve     | [ 6.4154314e-03 -8.9511415e-03 -7.3454739e-03 -1.7 |
|   33  |     machine     | [ 0.00480066 -0.00362838 -0.00426481  0.00121976 - |
|   34  |        a        | [-1.5039656e-03 -4.0219319e-03 -4.3986561e-03 -4.6 |
+-------+-----------------+----------------------------------------------------+

Original User Query: 'future of technology'

Retrieved Data for User Query: 'future of technology'
+------------+----------------+-------+--------------+----------------------------------------------------+
| Query Word | Predicted Word | Index |   Distance   |                     Embedding                      |
+------------+----------------+-------+--------------+----------------------------------------------------+
|   future   |     future     |   7   |     0.0      | [ 8.1687383e-03 -4.4429051e-03  8.9858277e-03  8.2 |
|   future   |    learning    |   19  | 0.004345501  | [ 5.6248982e-03  5.4965699e-03  1.8382617e-03  5.7 |
|     of     |       of       |   1   |     0.0      | [-8.6196875e-03  3.6657380e-03  5.1898835e-03  5.7 |
|     of     |   skip-gram    |   22  | 0.0054203393 | [-2.3460490e-04  4.2256210e-03  2.1153577e-03  1.0 |
| technology |   technology   |   6   |     0.0      | [ 8.1322715e-03 -4.4573355e-03 -1.0683584e-03  1.0 |
| technology |   embeddings   |   28  | 0.0046820673 | [ 8.3545828e-03 -5.6522468e-04 -9.4374381e-03  4.7 |
+------------+----------------+-------+--------------+----------------------------------------------------+
[Figure: cosine similarity heatmap for the query 'future of technology' and its retrieved words]
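
The FAISS index above ranks neighbors by squared L2 distance, while the heatmap reports cosine similarity. As a minimal sketch of a cosine-based variant (not part of the original notebook; names are illustrative), one can L2-normalize the vectors and use an inner-product index, so the score FAISS returns is the cosine similarity itself:
In [ ]:
import faiss
import numpy as np

def create_cosine_faiss_index(word_embeddings):
    # Stack the vectors as float32 (a FAISS requirement) and L2-normalize each row
    # in place; after normalization, inner product equals cosine similarity.
    matrix = np.array(list(word_embeddings.values()), dtype="float32")
    faiss.normalize_L2(matrix)
    index = faiss.IndexFlatIP(matrix.shape[1])  # IP = inner product
    index.add(matrix)
    return index, matrix

# Usage: search with a normalized query vector; scores near 1.0 mean near-identical direction.
cos_index, normed = create_cosine_faiss_index(word_embeddings)
scores, idxs = cos_index.search(normed[:1], 3)  # top-3 cosine neighbors of word 0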